

class Learner:
    """
    Drive one level of a two-level hierarchy.

    Each call to ``step`` advances the wrapped policy by one step, steps the
    reward modules, and periodically runs learning updates on both. The reward
    modules are intended to compute reward modifications; here they are stepped
    alongside the policy and take part in every learning update
    (``learn`` / ``after_update``).
    """

    def __init__(self, mpolicy, rew_modules, nlearn=1, delay=1, args=None):
        """
        :param mpolicy: policy driven by this learner. It must expose ``step``,
            ``change_goal``, ``learn``, ``after_update``, ``can_learn``,
            ``rollouts``, ``total_num_steps``, ``log_interval`` and ``eval_mode``.
        :param rew_modules: sequence of reward modules stepped and updated
            alongside the policy.
        :param nlearn: number of learning updates run each time learning triggers.
        :param delay: learning triggers whenever ``mpolicy.total_num_steps`` is
            a multiple of ``delay``.
        :param args: optional configuration object, stored as-is.
        """
        self.eval_mode = False

        self.args = args
        self.mpolicy = mpolicy
        self.rew_modules = rew_modules
        # Alias kept for code that uses the generic ``modules`` name.
        self.modules = self.rew_modules
        self.cpt_update = 0
        # Also set here so the attribute exists before the first reset().
        self.cpt_steps = -1
        self.nlearn = nlearn
        self.delay = delay

    def reset(self, **kwargs):
        """Reset the step counter and the policy; return the policy's reset result."""
        self.cpt_steps = -1
        return self.mpolicy.reset(**kwargs)

    def step(self, goal=None, change_goal=None, **kwargs):
        """
        Advance the policy by one step, then learn and log when due.

        :param goal: forwarded to ``mpolicy.step``.
        :param change_goal: forwarded to ``mpolicy.change_goal`` before stepping.
        :param kwargs: forwarded to ``mpolicy.step`` and to every reward
            module's ``step``.
        :return: ``(obs, action, reward, done, infos)`` as returned by
            ``mpolicy.step``.
        """
        self.mpolicy.change_goal(change_goal)
        obs, action, reward, done, infos = self.mpolicy.step(goal, **kwargs)

        for rewmod in self.rew_modules:
            rewmod.step(**kwargs)

        if self.mpolicy.total_num_steps % self.delay == 0:
            for _ in range(self.nlearn):
                # Only draw a fresh batch when the policy is ready and training.
                if self.mpolicy.can_learn() and not self.mpolicy.eval_mode:
                    self.mpolicy.rollouts.sample()

                for mod in self.modules:
                    mod.learn()
                self.mpolicy.learn()

                for rewmod in self.modules:
                    rewmod.after_update()
                self.mpolicy.after_update()

        # Checked independently of ``delay`` so statistics are printed every
        # ``log_interval`` steps even when ``delay`` does not divide it.
        if self.mpolicy.total_num_steps % self.mpolicy.log_interval == 0:
            self.print()

        return obs, action, reward, done, infos

    def eval(self):
        """Switch this learner, its reward modules and the policy to eval mode."""
        self.eval_mode = True
        for rewmod in self.modules:
            rewmod.eval()
        self.mpolicy.eval()

    def train(self):
        """Switch this learner, its reward modules and the policy to train mode."""
        self.eval_mode = False
        for rewmod in self.modules:
            rewmod.train()
        self.mpolicy.train()

    def save(self):
        """Save every reward module, then the policy."""
        for rewmod in self.modules:
            rewmod.save()
        self.mpolicy.save()

    def load(self):
        """Load every reward module, then the policy."""
        for rewmod in self.modules:
            rewmod.load()
        self.mpolicy.load()

    def print(self, **kwargs):
        """Print statistics of every reward module, then of the policy."""
        for mod in self.modules:
            mod.print(**kwargs)
        self.mpolicy.print(**kwargs)

    def can_learn(self):
        """Return True when the policy can learn, or whenever it is in eval mode."""
        return self.mpolicy.can_learn() or self.mpolicy.eval_mode

class LearnerAlone:
    """
    Learner to use when there is no upper hierarchy.

    Unlike ``Learner.step``, ``LearnerAlone.step`` runs a whole chunk of
    training: it keeps stepping the policy until the step counter reaches this
    epoch's share of ``num_updates``.
    """

    def __init__(self, mpolicy, rew_modules, nlearn=1, delay=1):
        """
        :param mpolicy: policy driven by this learner. It must expose ``step``,
            ``learn``, ``after_update``, ``can_learn``, ``rollouts``, ``envs``,
            ``total_num_steps`` and ``log_interval``.
        :param rew_modules: sequence of reward modules stepped and updated
            alongside the policy.
        :param nlearn: number of learning updates run each time learning triggers.
        :param delay: learning triggers whenever ``mpolicy.total_num_steps`` is
            a multiple of ``delay``.
        """
        self.eval_mode = False
        self.mpolicy = mpolicy
        self.rew_modules = rew_modules
        # Alias kept for code that uses the generic ``modules`` name.
        self.modules = self.rew_modules
        # Also set here so the attribute exists before the first reset().
        self.cpt_steps = -1
        self.nlearn = nlearn
        self.delay = delay

    def reset(self, **kwargs):
        """Reset the step counter and the policy; return the policy's reset result."""
        self.cpt_steps = -1
        return self.mpolicy.reset(**kwargs)

    def step(self, num_updates, total_epochs=1, epoch=1, **kwargs):
        """
        Step and train until the step counter reaches
        ``epoch * num_updates / total_epochs``.

        The counter is read from ``mpolicy.envs.mpolicy`` when the environment
        exposes such an attribute (a wrapped lower-level policy), otherwise
        from ``mpolicy`` itself.

        :param num_updates: total step budget across all epochs.
        :param total_epochs: number of epochs the budget is split into.
        :param epoch: current epoch; the loop stops at this epoch's cumulative share.
        :param kwargs: forwarded to ``mpolicy.step``.
        :return: ``(obs, action, reward, done, infos)`` from the last step, or
            five ``None`` values if the target was already reached.
        """
        obs, action, reward, done, infos = None, None, None, None, None

        if hasattr(self.mpolicy.envs, "mpolicy"):
            comparepolicy = self.mpolicy.envs.mpolicy
        else:
            comparepolicy = self.mpolicy
        # Invariant across the loop, so compute it once.
        target_steps = epoch * num_updates / total_epochs

        while comparepolicy.total_num_steps < target_steps:
            obs, action, reward, done, infos = self.mpolicy.step(**kwargs)
            for rewmod in self.rew_modules:
                rewmod.step()

            if self.mpolicy.total_num_steps % self.delay == 0:
                # Number of learning updates per triggered step.
                for _ in range(self.nlearn):
                    # Sample interactions to learn on.
                    if self.mpolicy.can_learn():
                        self.mpolicy.rollouts.sample()

                    # NOTE(review): unlike Learner.step, the reward modules'
                    # learn() is not called here - confirm this is intentional.
                    self.mpolicy.learn()

                    for rewmod in self.rew_modules:
                        rewmod.after_update()
                    self.mpolicy.after_update()

            if self.mpolicy.total_num_steps % self.mpolicy.log_interval == 0:
                self.print()

        return obs, action, reward, done, infos

    def save(self):
        """Save every reward module, then the policy."""
        for mod in self.modules:
            mod.save()
        self.mpolicy.save()

    def load(self):
        """Load every reward module, then the policy."""
        for mod in self.modules:
            mod.load()
        self.mpolicy.load()

    def eval(self):
        """Switch this learner, its reward modules and the policy to eval mode."""
        self.eval_mode = True
        for mod in self.modules:
            mod.eval()
        self.mpolicy.eval()

    def train(self):
        """Switch this learner, its reward modules and the policy to train mode."""
        self.eval_mode = False
        for mod in self.modules:
            mod.train()
        self.mpolicy.train()

    def print(self, **kwargs):
        """Print statistics of every reward module, then of the policy."""
        for mod in self.modules:
            mod.print(**kwargs)
        self.mpolicy.print(**kwargs)

    def can_learn(self):
        """Return True when the policy can learn, or whenever it is in eval mode."""
        return self.mpolicy.can_learn() or self.mpolicy.eval_mode



